from glob import glob
import os
import albumentations as A
import torch, numpy as np
from PIL import Image

# Top-level checkpoint keys whose values are too large to be worth printing
# (model parameters and optimizer state in a Lightning checkpoint).
SKIPPED_CKPT_KEYS = frozenset({'state_dict', 'optimizer_states'})


def print_checkpoint_summary(ckpt, skip_keys=SKIPPED_CKPT_KEYS):
    """Print a checkpoint's top-level keys, then every entry not in ``skip_keys``.

    Args:
        ckpt: Mapping loaded from a checkpoint file (e.g. via ``torch.load``).
        skip_keys: Top-level keys whose values are omitted from the per-entry
            dump because they are bulky.
    """
    print(ckpt.keys())
    for key, value in ckpt.items():
        if key in skip_keys:
            continue
        print(f'{key}:{value}')


if __name__ == "__main__":
    import sys

    # The default is the run this script was written against; pass a different
    # checkpoint path as the first CLI argument to inspect another run.
    default_ckpt_path = 'logs/vq-f8-n256/CLIPpreVQ_ViT-B-16/2023-03-01T19-19-25_CLIPpreVQ_ViT-B-16/checkpoints/last.ckpt'
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else default_ckpt_path

    # map_location='cpu' lets a checkpoint that was saved with GPU tensors be
    # loaded on a machine without CUDA; only metadata is inspected here.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    print_checkpoint_summary(ckpt)

    # Optional deeper inspection:
    # print(ckpt['state_dict'].keys())
    # print(list(ckpt['optimizer_states'])[1]['param_groups'])
    # print(list(ckpt['optimizer_states'])[1]['state'][0]['exp_avg'].shape)

    # path='data/ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_00000478.JPEG'
    # img = Image.open(path)
    # print(img)